#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <algorithm>               // std::sort
#include <cmath>                   // sqrt
#include <fstream>                 // std::ofstream
#include <iostream>                // std::cout
#include <limits>                  // std::numeric_limits<>
#include <tuple>                   // std::tie
#include <utility>                 // std::pair
#include <vector>                  // std::vector

// jSCIENCE
#include "jScience/stl/vector.hpp" // magnitude()
#include "jutility.hpp"            // get_n_rand_unique_elements()
#include "jstats.hpp"              // mean(), stddev()

// stats++
#include "statsxx/dataset.hpp"     // partition_dataset()
#include "statsxx/machine_learning/NeuralNet/utility.hpp" // fit_validation_error()


//
// DESC: Trains a neural network using the scaled conjugate gradient method.
//
// INPUT:
//          TS              : training set
//          lambda_k        : initial lambda value (0 < lambda0 < 10E-6)
//          sigma           : initial sigma (0 < sigma < 10^-4)
//
// NOTES:
//     ! this algorithm is a COMBINATION of the 1993 Moeller paper (the original method) and from the IBM website:
//          http://publib.boulder.ibm.com/infocenter/spssstat/v20r0m0/index.jsp?topic=%2Fcom.ibm.spss.statistics.help%2Falg_mlp_training_scaled-conjugate-gradients.htm
//     ! the suggested initial values above were taken from the Moeller paper
//     ! there is one slight issue regarding the implementation of this algorithm, regarding a mu vs mu^2 in the denominator of a fraction, since IBM and the Moeller paper differ -- I think this is straightened out though (see the notes below)
//     ! if the problem is a binary classification problem, the output layer should consist of logistic activation functions, and if a multinomial, they should be the softmax functions -- if they are not and classification is set, this routine may (is bound) to fail
//     !!! JMM perhaps an error check of the above should be put in
//     ! on a classification problem, we do NOT need to scale the output, assuming the user assigned 0,1 values correctly
//
inline void NEURAL_NET::SCG(
                            const int      MINITER,
                            const int      MAXITER,
                            // -----
                            const bool     early_stopping,
                            // -----
                            double         lambda_k,
                            const double   sigma,
                            // -----
                            const DataSet &TS_tr,
                            const DataSet &TS_val,
                            // -----
                            const double   CONVG_ITERFRAC,
                            const double   rkTOL,
                            // -----
                            const bool     silent
                            )
{
    // NOTE(review): MINITER and CONVG_ITERFRAC are currently only referenced
    // by the commented-out "best exit" stopping criterion below; they remain
    // in the interface so callers need not change when it is re-enabled.

//    int MINITER            = 10;
//    int MAXITER            = 100;

//    double CONVG_ITERFRAC  = 0.25;
//    double rkTOL           = 1.0E-6;      // EXIT CONDITION (VALUE FROM THE IBM WEBSITE)

    //*********************************************************

    //=========================================================
    // SETUP ALL VECTORS NEEDED (w, wtheta, p, r, sigma, s, delta), & INITIALIZE TO 0
    //=========================================================

    // FLAGS
    int exit_condition;   // 0: hit MAXITER; 1: |r_k| < rkTOL; 2: early-convergence exit (currently disabled)
    bool success;         // SCG "success" flag: did the last step reduce the error?
    // SCG SCALARS & VECTORS
    std::vector<double> w, p_w, s_w, r_w, rk_w;
    //std::vector<double> w, w_best, p_w, s_w, r_w, rk_w;
    double sigma_k, lambdaBar_k, delta_k, mu_k, alpha_k;
    // WORK VECTORS
    std::vector<double> work_w1, work_w2;
    // CONVG/STORAGE: per-successful-iteration training/validation errors, and
    // the weight vector after each successful iteration (for early stopping)
    std::vector<double> Etr, Etr_stddev, Ev, Ev_stddev;
    std::vector<std::vector<double>> w_store;

//    double min_err;
//    int min_err_pt;

    w.resize(m_links.size(), 0.0);
//    w_best.resize(m_links.size(), 0.0);
    p_w.resize(m_links.size(), 0.0);
    s_w.resize(m_links.size(), 0.0);
    r_w.resize(m_links.size(), 0.0);
    rk_w.resize(m_links.size(), 0.0);

    work_w1.resize(m_links.size(), 0.0);
    work_w2.resize(m_links.size(), 0.0);

    std::ofstream file_err;
    if( !silent ) { file_err.open("./training_errs.dat", std::ios::app); }


    //=========================================================
    // ASSIGN w AND w_best
    //=========================================================

    for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
    {
        w[i] = m_links[i].weight;
    }

//    w_best = w;


    //=========================================================
    // STEP 1. CHOOSE WEIGHT VECTOR & SCALARS & SET INITIAL VECTORS
    //=========================================================
    // ! lambda_k and sigma have both been passed already set to this routine

    lambdaBar_k = 0.0;

    // CALCULATE E'(w1) ...
    assign_weights(w);
    calulate_Ederiv( TS_tr, work_w1 );
    // ... AND SET p1 = r1 = -E'(w1)
    for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
    {
        p_w[i] = r_w[i] = -work_w1[i];
    }

    success = true;

//    min_err    = std::numeric_limits<double>::max();
//    min_err_pt = 0;

    exit_condition = 0;

    // ****************************
    //DataSet TS_full = TS_tr;
    // ****************************

    // MAIN ITERATION LOOP ....
    for(int ITER = 1; ITER <= MAXITER; ++ITER)
    {
        // ****************************
        //TS_tr = DataSet();
        //TS_tr.pt = get_n_rand_unique_elements(500, TS_full.pt);
        // ****************************

        // honor the silent flag for ALL console output (previously the
        // progress/debug prints ignored it)
        if( !silent ) { std::cout << "starting ITER: " << ITER << std::endl; }

        double pk2, pkmag, rk2, rk1rk;
        std::vector<std::vector<double>> NN_out;

        // CALCULATE |p_k| & |p_k|^2 ...
        pk2 = 0.0;

        for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
        {
            pk2 += p_w[i]*p_w[i];
        }

        pkmag = sqrt( pk2 );


        //=========================================================
        // STEP 2. IF SUCCESS == TRUE, THEN CALCULATE SECOND-ORDER INFORMATION
        //=========================================================
        // ! on a failed step the previous s_k/delta_k are reused (per Moeller);
        //   only the scaling in steps 3/4 changes via the updated lambda_k

        if( success )
        {
            sigma_k = sigma/pkmag;


            // CALCULATE work1 = (w_k + sigma_k*p_k) ...
            for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
            {
                work_w1[i] = w[i] + sigma_k*p_w[i];
            }

            // CALCULATE work2 = E'(w_k + sigma_k*p_k) ...
            assign_weights(work_w1);
            calulate_Ederiv( TS_tr, work_w2 );

            // CALCULATE work1 = E'(w_k) ...
            assign_weights(w);
            calulate_Ederiv( TS_tr, work_w1 );

            // NOW CALCULATE s_k = (E'(w_k + sigma_k*p_k) - E'(w_k))/sigma_k ...
            // (finite-difference approximation of the Hessian-vector product)
            for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
            {
                s_w[i] = (work_w2[i] - work_w1[i])/sigma_k;
            }


            // CALCULATE delta_k = (p_k^T)*s_k ...
            delta_k = 0.0;
            for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
            {
                delta_k += p_w[i]*s_w[i];
            }
        }


        //=========================================================
        // STEP 3. SCALE delta_k: delta_k = delta_k + (lambda_k - lambdaBar_k)*|p_k|^2
        //=========================================================

        delta_k += (lambda_k - lambdaBar_k)*pk2;


        //=========================================================
        // STEP 4. IF delta_k <= 0, THEN MAKE THE HESSIAN MATRIX POSITIVE DEFINITE
        //=========================================================

        if( delta_k <= 0.0 )
        {
            lambdaBar_k = 2.0*( lambda_k - delta_k/pk2 );

            delta_k     = -delta_k + lambda_k*pk2;

            lambda_k    = lambdaBar_k;
        }


        //=========================================================
        // STEP 5. CALCULATE STEP SIZE
        //=========================================================

        // FIRST CALCULATE mu_k = (p_k^T)*r_k ...
        mu_k = 0.0;
        for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
        {
            mu_k += p_w[i]*r_w[i];
        }

        alpha_k = mu_k/delta_k;


        //=========================================================
        // STEP 6. CALCULATE COMPARISON PARAMETER comp_param = (2.0*delta_k*(E(w_k) - E(w_k + alpha_k*p_k)))/mu_k^2
        //=========================================================
        // << JMM >> :: The IBM website only has a mu_k in the denominator (not squared), whereas the Moeller paper has a mu^2. Further, testing time and time again seems to show that mu_k seems to give better results (training can stall by accelerated step reduction with mu^2), so I will stick with mu here

        // FIRST CALCULATE work1 = (w_k + alpha_k*p_k) ...
        for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
        {
            work_w1[i] = w[i] + alpha_k*p_w[i];
        }

        assign_weights(w);

        double w0error;
        double w0stddev;
        std::tie( w0error, w0stddev, NN_out) = parse_TS(TS_tr);

        assign_weights(work_w1);

        double w1error;
        double w1stddev;
        std::tie( w1error, w1stddev, NN_out) = parse_TS(TS_tr);

        double comp_param = (2.0*delta_k*(w0error - w1error))/mu_k;
        //double comp_param = (2.0*delta_k*(w0error - w1error))/(mu_k*mu_k);


        //=========================================================
        // STEP 7. IF comp_param >= 0, THEN A SUCCESSFUL REDUCTION IN ERROR CAN BE MADE
        //=========================================================

        if( comp_param >= 0.0 )
        {
            // DEBUG: dump old vs. accepted weights (now gated by `silent`)
            if( !silent )
            {
                for( decltype(w.size()) i = 0; i < w.size(); ++i )
                {
                    std::cout << "w[i]: " << w[i] << "   " << work_w1[i] << '\n';
                }
            }

            // ASSIGN w_{k+1} = w_k + alpha_k*p_k (WHICH IS STILL THE work1 VECTOR)
            w = work_w1;

            // CALCULATE work1 = E'(w_k+1), STORING IN work1 VECTORS
            assign_weights(w);
            calulate_Ederiv( TS_tr, work_w1 );

            // FIRST STORE r_k BEFORE UPDATING TO k+1, IF WE NEED IT BELOW
            if( (ITER % m_links.size()) != 0 )
            {
                rk_w = r_w;
            }

            // ASSIGN r_{k+1} = -E'(w_{k+1})
            for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
            {
                r_w[i] = -work_w1[i];
            }

            // CHECK FOR CONVERGENCE (AND CALCULATE |r_k|^2 AS A BY-PRODUCT)
            // ! this is out of order from the Moeller paper, but there is no reason NOT to check for convergence here -- although, we do miss the outputting of error at this step
            rk2 = 0.0;
            for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
            {
                rk2 += r_w[i]*r_w[i];
            }

            if( sqrt(rk2) < rkTOL )
            {
                exit_condition = 1;
                break;
            }

            lambdaBar_k = 0.0;

            success = true;


            //=========================================================
            // STEP 7a. IF (k % N == 0) ...
            //=========================================================

            // ... RESTART ALGORITHM ...
            if( (ITER % m_links.size()) == 0 )
            {
                // ASSIGN p_{k+1} = r_{k+1}
                p_w     = r_w;

                // << JMM >> :: is it necessary to take any further action (e.g., go back to start)?
            }
            // ... ELSE CREATE A NEW CONJUGATE DIRECTION
            else
            {
                // CALCULATE (r_{k+1}^T)*r_k (WE SHOULD HAVE STORED r_k ABOVE, AND ALSO CALCULATED |r_{k+1}|^2) ...
                rk1rk = 0.0;
                for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
                {
                    rk1rk += r_w[i]*rk_w[i];
                }

                // NOW CALCULATE beta_k = (|r_{k+1}|^2 - (r_{k+1}^T)*r_k)/mu_k ...
                double beta_k = (rk2 - rk1rk)/mu_k;

                // NOW CALCULATE p_{k+1} = r_{k+1} + beta_k*p_k ...
                for( decltype(m_links.size()) i = 0; i < m_links.size(); ++i )
                {
                    p_w[i] = r_w[i] + beta_k*p_w[i];
                }
            }

            //=========================================================
            // STEP 7b. IF comp_param >= 0.75, REDUCE THE SCALE PARAMETER: lambda_k = (1/4)*lambda_k
            //=========================================================

            if( comp_param >= 0.75 )
            {
                lambda_k *= 0.25;
            }
        }
        // ... ELSE A REDUCTION IN ERROR IS NOT POSSIBLE
        else
        {
            lambdaBar_k = lambda_k;

            success     = false;
        } // end if( comp_param >= 0.0 )


        //=========================================================
        // STEP 8. IF comp_param < 0.25, THEN INCREASE THE SCALE PARAMETER: lambda_k = lambda_k + (delta_k*(1 - comp_param))/|p_k|^2
        //=========================================================
        // ! note that |p_k|^2 is at time k, but p is now stored at k+1, BUT we still have pk2 stored

        if( comp_param < 0.25 )
        {
            lambda_k += delta_k*(1.0 - comp_param)/pk2;
        }


        //=========================================================
        // STEP 9. CHECK FOR STOPPING CRITERIA
        //=========================================================

        if( !silent ) { std::cout << "    success: " << success << std::endl; }

        if( success )
        {
            double Emean;
            double Estddev;

            std::tie( Emean, Estddev, NN_out) = parse_TS(TS_tr);
            Etr.push_back(Emean);
            Etr_stddev.push_back(Estddev);

            std::tie( Emean, Estddev, NN_out) = parse_TS(TS_val);
            Ev.push_back(Emean);
            Ev_stddev.push_back(Estddev);

            if( !silent )
            {
                // the output for this starts at 1
                file_err << Etr.size() << "   " << Etr.back() << "   " << Etr_stddev.back() << "   " << Ev.back() << "   " << Ev_stddev.back() << std::endl;
            }

            w_store.push_back(w);

/*
            if( err_val <= min_err ) // the `=' is used here to ensure that the minimum error is assigned to the highest point
            {
                min_err    = err_val;
                min_err_pt = ITER;

                w_best     = w;
            }
*/
/*
            // << JMM: need a better stopping criterion >>
            if( ((static_cast<double>(min_err_pt)/static_cast<double>(err.size())) < CONVG_ITERFRAC) && (ITER >= MINITER) )
            {
                exit_condition = 2;
                break;
            }
*/
        }
    } // ++ITER


    if( !silent )
    {
        // (fixed the function name in the diagnostic: it previously read "CNEURAL_NET::SCG")
        std::cout << "Exit condition of NEURAL_NET::SCG was " << exit_condition;
        switch( exit_condition )
        {
            case 0:
                std::cout << " (too many iterations)" << std::endl;
                break;
            case 1:
                std::cout << " [sqrt(rk2) < rkTOL]" << std::endl;
                break;
            case 2:
                std::cout << " (min. error occurred in first CONVG_ITERFRAC of iterations -- best exit type)" << std::endl;
                break;
            default:
                break;
        }
    }

    if( w_store.empty() )
    {
        // No iteration ever succeeded, so nothing was stored: keep the last
        // accepted weights instead of indexing an empty container (the
        // previous code invoked undefined behavior here).
        assign_weights(w);
    }
    else if( early_stopping )
    {
        // NOTE: See NOTEs in ::backpropagation().

        std::vector<double>::size_type indx;
        std::vector<double>            f;
        std::tie(
                 f,
                 indx
                 ) = fit_validation_error(
                                          Ev,
                                          Ev_stddev
                                          );

        // NOTE(review): this file is written regardless of `silent`, matching
        // the original behavior -- confirm whether it should also be gated
        std::ofstream ofs2("fit_valid_err.dat");

        ofs2 << "using point: " << indx << '\n';

        for( decltype(f.size()) i = 0; i < f.size(); ++i )
        {
            ofs2 << (i+1) << "   " << f[i] << '\n';
        }

        ofs2.close();

        assign_weights(w_store[indx]);
    }
    else
    {
        assign_weights(w_store.back());
    }

    // ***
/*
    std::vector<double> dEdw;
    this->calulate_Ederiv(TS_tr, dEdw);

    double lambda = magnitude(dEdw)/(2.0*magnitude(w_store[indx]));

    std::cout << "lambda: " << lambda << std::endl;
    int oo;
    std::cin >> oo;
*/
    // ***



//    assign_weights(w_best);


    if( !silent )
    {
        // LEAVE ONE EXTRA LINE IN OUTPUT FILE TO SEPARATE DATA BEFORE CLOSING...
        file_err << std::endl;
        file_err.close();
    }
}
